Fastai2 is a Layered API

What does that mean?

Let's look at an image example first, using the High-Level API. We will use:

  • ImageDataLoader
  • cnn_learner

Resources used for Demo:

Importing the vision module

General rule of thumb: fastai.x, where x is the application:

  • Vision
  • Tabular
  • Collab
  • Text
In [1]:
from fastai.vision.all import *

Downloading the dataset: we will be looking at the PETS dataset

In [2]:
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

We need to tell our model what the labels are:

In the PETS dataset, cat images have filenames starting with an uppercase letter, so we write label_func():

In [3]:
def label_func(f): return f[0].isupper()
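
A quick sanity check of this labeling rule on two illustrative filenames (the breed names here follow the PETS naming convention, where cat breeds are capitalized and dog breeds are lowercase):

```python
def label_func(f): return f[0].isupper()

# Cat breeds are capitalized, dog breeds are lowercase
print(label_func("Abyssinian_1.jpg"))  # True  -> cat
print(label_func("beagle_1.jpg"))      # False -> dog
```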

DataLoaders

  • As the name suggests, DataLoaders are a way of feeding data to our models.
  • xDataLoaders: x can be any application supported by fastai
  • Further, DataLoaders have different helper functions to allow feeding data in the most common formats
In [4]:
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(224))

Let's take a look at the first few images:

In [5]:
dls.show_batch()

Dogs >>>Cats. I said it 😁

Creating a Model

Let's create our Machine Learning Model! More accurately, a CNN.

We'll use a cnn_learner for our task

In [6]:
learn = cnn_learner(dls, resnet34, metrics=error_rate)
learn.fine_tune(1)
epoch train_loss valid_loss error_rate time
0 0.162312 0.018972 0.005413 00:17
epoch train_loss valid_loss error_rate time
0 0.048161 0.013592 0.003383 00:21

Looking at results:

In [7]:
learn.show_results()

Taking the training wheels off: the DataBlock API

  • "Mid-level" API, you need to define:
    • blocks: the types that define the problem, here an image input and a categorical target
    • splitter: how to create the train/valid split
    • get_y: how to get the labels
    • transforms (more on these later)
In [8]:
pets = DataBlock(blocks=(ImageBlock, CategoryBlock), 
                 get_items=get_image_files, 
                 splitter=RandomSplitter(),
                 get_y=using_attr(RegexLabeller(r'(.+)_\d+.jpg$'), 'name'),
                 item_tfms=Resize(460),
                 batch_tfms=aug_transforms(size=224))
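
The RegexLabeller above extracts the breed from the file's name attribute. A small pure-Python sketch of what that regex does, using a hypothetical filename in the PETS naming style:

```python
import re

# the same pattern passed to RegexLabeller above
pat = re.compile(r'(.+)_\d+.jpg$')

def breed_from_name(name):
    # group 1 captures everything before the trailing _<number>.jpg
    return pat.search(name).group(1)

print(breed_from_name("great_pyrenees_173.jpg"))  # great_pyrenees
```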

Now we're ready to load all of this into DataLoaders and show the first batch:

In [9]:
dls = pets.dataloaders(untar_data(URLs.PETS)/"images")
In [10]:
dls.show_batch(max_n=9)

Multi-Label Classification

The original task of this dataset is object detection; let's use it for a multi-label image classification example.

First, let's look at the high-level API:

In [64]:
path = untar_data(URLs.PASCAL_2007)
path.ls()
Out[64]:
(#8) [Path('/storage/data/pascal_2007/train.csv'),Path('/storage/data/pascal_2007/valid.json'),Path('/storage/data/pascal_2007/test.json'),Path('/storage/data/pascal_2007/test.csv'),Path('/storage/data/pascal_2007/test'),Path('/storage/data/pascal_2007/train'),Path('/storage/data/pascal_2007/train.json'),Path('/storage/data/pascal_2007/segmentation')]
In [65]:
df = pd.read_csv(path/'train.csv')
df.head()
Out[65]:
fname labels is_valid
0 000005.jpg chair True
1 000007.jpg car True
2 000009.jpg horse person True
3 000012.jpg car False
4 000016.jpg bicycle True
In [66]:
dls = ImageDataLoaders.from_df(df, path, folder='train', valid_col='is_valid', label_delim=' ',
                               item_tfms=Resize(460), batch_tfms=aug_transforms(size=224))
In [67]:
dls.show_batch()
In [15]:
learn = cnn_learner(dls, resnet50, metrics=partial(accuracy_multi, thresh=0.5))
In [16]:
learn.fine_tune(4, 3e-2)
epoch train_loss valid_loss accuracy_multi time
0 0.431439 0.137803 0.958088 00:23
epoch train_loss valid_loss accuracy_multi time
0 0.160930 0.344820 0.929502 00:25
1 0.182072 0.222219 0.930856 00:25
2 0.159581 0.120245 0.956115 00:25
3 0.129626 0.106977 0.961155 00:25
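
The accuracy_multi metric thresholds each sigmoid output independently and compares it against the one-hot targets. A minimal pure-Python sketch of that logic (the real fastai metric operates on tensors):

```python
def accuracy_multi(probs, targets, thresh=0.5):
    """Fraction of (sample, class) pairs where (prob > thresh) matches the target."""
    hits = total = 0
    for p_row, t_row in zip(probs, targets):
        for p, t in zip(p_row, t_row):
            hits += int((p > thresh) == bool(t))
            total += 1
    return hits / total

probs   = [[0.9, 0.2, 0.6], [0.1, 0.8, 0.4]]
targets = [[1,   0,   1  ], [0,   1,   1  ]]
print(accuracy_multi(probs, targets))  # 5 of 6 pairs correct
```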
In [17]:
learn.show_results()
In [18]:
learn.predict(path/'train/000005.jpg')
Out[18]:
((#1) ['chair'],
 tensor([False, False, False, False, False, False, False, False,  True, False,
         False, False, False, False, False, False, False, False, False, False]),
 tensor([2.5324e-04, 2.4844e-03, 1.5633e-04, 4.1936e-04, 4.7598e-02, 7.2696e-04,
         1.2475e-03, 1.2292e-02, 9.2737e-01, 6.3124e-05, 3.2658e-01, 5.5052e-03,
         1.7077e-04, 6.3773e-04, 6.7024e-02, 1.6096e-01, 3.8893e-05, 1.3953e-01,
         2.7315e-04, 4.5196e-01]))
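
The decoded labels in that output are simply the vocab entries whose probability clears the threshold. A sketch of the decoding step, with an illustrative (shortened) vocab:

```python
def decode_preds(probs, vocab, thresh=0.5):
    # keep classes whose predicted probability clears the threshold
    return [c for p, c in zip(probs, vocab) if p > thresh]

vocab = ['aeroplane', 'bicycle', 'bird', 'chair']  # illustrative subset
probs = [0.02, 0.10, 0.40, 0.93]
print(decode_preds(probs, vocab))  # ['chair']
```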
In [19]:
interp = Interpretation.from_learner(learn)
interp.plot_top_losses(9)
target predicted probabilities loss
0 car;person;train car tensor([1.2728e-05, 1.2785e-03, 8.7742e-05, 2.1169e-03, 3.8385e-05, 1.5920e-03,\n 9.9974e-01, 4.5645e-06, 4.5770e-05, 4.4169e-06, 8.3377e-07, 3.0897e-06,\n 4.5148e-07, 9.0592e-05, 1.6643e-02, 1.7029e-03, 2.8882e-07, 1.7818e-04,\n 1.3530e-05, 5.2989e-06]) 0.765691876411438
1 chair;diningtable;person person;train tensor([5.7383e-05, 1.1879e-02, 6.5730e-04, 1.4720e-03, 3.9942e-04, 9.2216e-02,\n 4.6772e-01, 2.8637e-04, 3.0233e-03, 7.8617e-04, 1.0003e-03, 5.1493e-05,\n 1.3683e-03, 2.6933e-02, 8.6169e-01, 1.1305e-02, 1.1546e-04, 5.1263e-04,\n 5.5786e-01, 7.0504e-04]) 0.7229101657867432
2 dog;pottedplant;sofa;tvmonitor cat tensor([5.0364e-05, 2.0883e-03, 3.9267e-03, 6.3734e-04, 1.5945e-02, 1.3643e-04,\n 9.7530e-04, 7.5495e-01, 2.8518e-01, 2.2989e-03, 1.7217e-02, 2.4789e-01,\n 9.3987e-04, 4.1881e-04, 2.2814e-02, 2.5245e-02, 1.1538e-03, 2.0396e-01,\n 1.7490e-03, 9.7245e-03]) 0.6554890275001526
3 car;person;tvmonitor car tensor([2.7852e-05, 4.2245e-03, 1.7206e-04, 5.3499e-03, 6.6080e-04, 4.6662e-04,\n 9.9932e-01, 2.6080e-05, 3.3175e-04, 1.1131e-05, 2.2962e-05, 2.0592e-05,\n 2.1709e-06, 3.5251e-03, 1.9651e-01, 4.8418e-03, 1.4130e-06, 4.4774e-04,\n 8.2517e-06, 3.4723e-05]) 0.5958005785942078
4 chair;dog;person;pottedplant;sofa;tvmonitor person;sofa tensor([7.0382e-05, 1.2801e-03, 2.9246e-03, 4.1490e-04, 1.1424e-02, 1.7606e-04,\n 3.6302e-03, 1.6316e-02, 2.8765e-01, 7.9041e-04, 1.2357e-02, 8.0666e-02,\n 6.2717e-04, 5.0731e-04, 9.9437e-01, 6.4058e-02, 1.2751e-03, 9.4651e-01,\n 1.3236e-04, 7.2852e-03]) 0.5773086547851562
5 dog;person;sofa person tensor([5.8521e-04, 1.9271e-03, 3.6062e-02, 9.3929e-04, 5.2378e-02, 4.8707e-03,\n 6.1535e-03, 4.0810e-03, 3.2776e-02, 4.2985e-04, 8.1750e-03, 4.9449e-03,\n 6.2204e-04, 1.5876e-03, 9.9790e-01, 1.3130e-02, 4.9901e-04, 2.4019e-03,\n 8.2174e-04, 9.5609e-03]) 0.576021671295166
6 chair;dog;person;pottedplant person tensor([1.7474e-04, 1.3529e-03, 1.2851e-01, 1.8144e-03, 5.0327e-03, 3.5432e-03,\n 4.1505e-03, 6.6433e-04, 2.4494e-02, 4.4907e-04, 2.4489e-03, 1.6330e-03,\n 2.5702e-03, 3.1727e-03, 9.9905e-01, 3.9563e-01, 6.6959e-04, 7.8110e-04,\n 1.1828e-04, 7.4699e-03]) 0.561348557472229
7 bottle;chair;pottedplant;sofa;tvmonitor chair;person;sofa tensor([2.5779e-04, 3.1850e-03, 9.8818e-04, 5.8826e-04, 2.3997e-02, 8.5256e-04,\n 3.8403e-03, 1.4846e-02, 6.8454e-01, 3.7760e-04, 8.6851e-02, 2.9018e-02,\n 5.1904e-04, 2.3399e-03, 9.4869e-01, 1.9272e-01, 4.4384e-04, 7.8736e-01,\n 1.5667e-04, 1.3009e-01]) 0.5576278567314148
8 bus;person car tensor([2.2211e-04, 2.8686e-03, 1.1502e-04, 2.7054e-03, 5.9515e-04, 6.0934e-02,\n 9.9416e-01, 1.0554e-04, 4.8675e-04, 1.4888e-05, 4.8082e-05, 2.7571e-05,\n 7.9668e-06, 3.0507e-04, 5.5840e-02, 2.3926e-03, 2.2317e-06, 6.2846e-04,\n 8.5646e-04, 6.2537e-04]) 0.5419460535049438

Mid-Level API

Let's Recap:

The DataBlock API needs:

  • Blocks
  • Splitters
  • Methods to grab inputs and labels
  • item and batch transforms
In [21]:
pascal = DataBlock(blocks=(ImageBlock, MultiCategoryBlock),
                   splitter=ColSplitter('is_valid'),
                   get_x=ColReader('fname', pref=str(path/'train') + os.path.sep),
                   get_y=ColReader('labels', label_delim=' '),
                   item_tfms = Resize(460),
                   batch_tfms=aug_transforms(size=224))
In [22]:
dls = pascal.dataloaders(df)
In [23]:
dls.show_batch(max_n=9)

Text

The goal of showcasing this is to demonstrate the similarity between different applications

In [69]:
from fastai.text.all import *

Loading the IMDB dataset: we'll try to classify movie reviews

In [70]:
path = untar_data(URLs.IMDB)
path.ls()
Out[70]:
(#7) [Path('/storage/data/imdb/README'),Path('/storage/data/imdb/tmp_lm'),Path('/storage/data/imdb/imdb.vocab'),Path('/storage/data/imdb/tmp_clas'),Path('/storage/data/imdb/test'),Path('/storage/data/imdb/train'),Path('/storage/data/imdb/unsup')]

Setting our DataLoader

xDataLoaders, x=Text

In [72]:
dls = TextDataLoaders.from_folder(untar_data(URLs.IMDB), valid='test')

We'll need to create a Learner for text classification, aptly named text_classifier_learner

In [74]:
learn = text_classifier_learner(dls, AWD_LSTM, drop_mult=0.5, metrics=accuracy)

Creating an LSTM Model

In [75]:
learn.fine_tune(4, 1e-2)
epoch train_loss valid_loss accuracy time
0 0.604561 0.395645 0.826840 01:52
epoch train_loss valid_loss accuracy time
0 0.314829 0.326577 0.862520 03:31
1 0.241916 0.229754 0.911320 03:37
2 0.167968 0.208609 0.922080 03:31
3 0.171125 0.194013 0.924800 03:23
In [76]:
learn.fine_tune(4, 1e-2)
epoch train_loss valid_loss accuracy time
0 0.179995 0.217087 0.920200 01:55
epoch train_loss valid_loss accuracy time
0 0.165780 0.204109 0.923240 03:38
1 0.144565 0.217666 0.921160 03:32
2 0.114676 0.216971 0.926680 03:31
3 0.099624 0.207423 0.929880 03:36

Surprisingly, this model isn't bad at all!

In [78]:
learn.predict("I really liked that movie!")
Out[78]:
('pos', tensor(1), tensor([0.0015, 0.9985]))

DataBlocks API

  • Blocks: Text, Category
  • Item Types
  • Grabbing Parent Labels
  • Splitter
In [80]:
imdb = DataBlock(blocks=(TextBlock.from_folder(path), CategoryBlock),
                 get_items=get_text_files,
                 get_y=parent_label,
                 splitter=GrandparentSplitter(valid_name='test'))
In [81]:
dls = imdb.dataloaders(path)
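
Both parent_label and GrandparentSplitter work purely off the folder structure: the class is the parent folder, and the split is decided by the grandparent folder. A rough pure-Python sketch of the idea, on a hypothetical IMDB-style path:

```python
from pathlib import PurePosixPath

def parent_label(p):
    # class name = name of the immediate parent folder
    return PurePosixPath(p).parent.name

def is_valid(p, valid_name='test'):
    # split by the grandparent folder: train/ vs test/
    return PurePosixPath(p).parent.parent.name == valid_name

f = 'imdb/train/pos/review_0.txt'
print(parent_label(f))  # pos
print(is_valid(f))      # False
```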

Tabular Demo

Let's look at a tabular dataset to further illustrate the similarities

In [87]:
from fastai.tabular.all import *

path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')
In [88]:
dls = TabularDataLoaders.from_csv(path/'adult.csv', path=path, y_names="salary",
    cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race'],
    cont_names = ['age', 'fnlwgt', 'education-num'],
    procs = [Categorify, FillMissing, Normalize])
In [89]:
splits = RandomSplitter(valid_pct=0.2)(range_of(df))
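
RandomSplitter shuffles the item indices (optionally with a seed) and carves off valid_pct of them for validation. A pure-Python sketch mimicking that behavior, returning (train, valid) index lists like fastai does:

```python
import random

def random_splitter(n_items, valid_pct=0.2, seed=42):
    """Shuffle indices and split off valid_pct of them as the validation set."""
    idxs = list(range(n_items))
    random.Random(seed).shuffle(idxs)
    cut = int(n_items * valid_pct)
    return idxs[cut:], idxs[:cut]   # (train, valid)

train, valid = random_splitter(10)
print(len(train), len(valid))  # 8 2
```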
In [91]:
to = TabularPandas(df, procs=[Categorify, FillMissing,Normalize],
                   cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race'],
                   cont_names = ['age', 'fnlwgt', 'education-num'],
                   y_names='salary',
                   splits=splits)
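
The Normalize proc fits a mean and standard deviation on the training split for each continuous column, then standardizes with those stats everywhere. A rough sketch for a single column:

```python
import statistics

def normalize_fit(train_vals):
    # stats come from the training split only, then get reused on valid/test
    mean = statistics.fmean(train_vals)
    std = statistics.pstdev(train_vals)
    return lambda xs: [(x - mean) / std for x in xs]

norm = normalize_fit([10, 20, 30, 40])
print([round(v, 3) for v in norm([10, 40])])  # [-1.342, 1.342]
```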
In [92]:
dls = to.dataloaders(bs=64)
In [93]:
learn = tabular_learner(dls, metrics=accuracy)
learn.fit_one_cycle(1)
epoch train_loss valid_loss accuracy time
0 0.343695 0.362684 0.830620 00:06

Image Augmentations

  • Flip
  • Dihedral
  • CropPad
  • Squish
  • RatioResize

GPU Transforms:

  • RandomResizedCropGPU
  • Flip
  • Dihedral
  • Zoom
  • RandomCentres
  • Warping (Vertical Vs Horizontal)
  • Brightness
  • Contrast
  • Grayscale
  • Saturation
  • HSV2RGB
  • Hue
  • RandomErasing

Setup:

In this walkthrough, we will look at two kinds of images for each transform:

  • Our reference doggy
  • A suggested use-case image (the med, sat, and box images loaded below)

Let's load the Doggy image into memory and take a look at it:

In [2]:
path = untar_data(URLs.PETS)
DOGGY = path/'images/beagle_1.jpg'
In [3]:
DOGGY
Out[3]:
Path('/storage/data/oxford-iiit-pet/images/beagle_1.jpg')
In [5]:
_MED = Path("./")
MED = _MED/"med.jpg"
In [6]:
_SAT = Path("./")
SAT = _SAT/"sat.jpg"
In [7]:
_BOX = Path("./")
BOX = _BOX/"box.jpg"

Resize

In [8]:
img = PILImage(PILImage.create(DOGGY).resize((600,400)))
In [9]:
med = PILImage(PILImage.create(MED).resize((600,400)))
In [10]:
sat = PILImage(PILImage.create(SAT).resize((600,400)))
In [11]:
box = PILImage(PILImage.create(BOX).resize((600,400)))

Horizontal Flip

In [12]:
_,axs = subplots(1,2)
show_image(img, ctx=axs[0], title='original')
show_image(img.flip_lr(), ctx=axs[1], title='flipped');
In [13]:
_,axs = subplots(1,2)
show_image(sat, ctx=axs[0], title='original')
show_image(sat.flip_lr(), ctx=axs[1], title='flipped');
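
Conceptually, flip_lr just mirrors the pixel grid left-to-right. On a toy 2D "image" represented as a list of rows, that's a per-row reversal:

```python
def flip_lr(img):
    # reverse each row of a 2D image (list of rows of pixels)
    return [row[::-1] for row in img]

img = [[1, 2, 3],
       [4, 5, 6]]
print(flip_lr(img))  # [[3, 2, 1], [6, 5, 4]]
```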

Dihedral Flip

In [28]:
_,axs = subplots(2, 4)
for ax in axs.flatten():
    show_image(DihedralItem(p=1.)(img, split_idx=0), ctx=ax)
In [130]:
_,axs = subplots(2, 4)
for ax in axs.flatten():
    show_image(DihedralItem(p=1.)(sat, split_idx=0), ctx=ax)
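
The dihedral group of a square has eight elements: four rotations plus their mirrored versions. On a toy grid they can all be generated from two primitives, a 90° rotation and a left-right flip:

```python
def rot90(img):
    # rotate a 2D grid 90 degrees counter-clockwise
    return [list(row) for row in zip(*img)][::-1]

def flip_lr(img):
    return [row[::-1] for row in img]

def dihedral(img, k):
    """k in 0..7: rotate k%4 times, then mirror if k >= 4."""
    out = [list(r) for r in img]
    for _ in range(k % 4):
        out = rot90(out)
    if k >= 4:
        out = flip_lr(out)
    return out

base = [[1, 2], [3, 4]]
variants = [dihedral(base, k) for k in range(8)]
# all 8 symmetries are distinct for an asymmetric grid
print(len({str(v) for v in variants}))  # 8
```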

CropPad

In [29]:
_,axs = plt.subplots(1,3,figsize=(12,4))
for ax,sz in zip(axs.flatten(), [300, 500, 700]):
    show_image(img.crop_pad(sz), ctx=ax, title=f'Size {sz}');
In [131]:
_,axs = plt.subplots(1,3,figsize=(12,4))
for ax,sz in zip(axs.flatten(), [300, 500, 700]):
    show_image(med.crop_pad(sz), ctx=ax, title=f'Size {sz}');

Resize with Options

  • Zeros
  • Border
  • Reflection
In [30]:
_,axs = plt.subplots(1,3,figsize=(12,4))
for ax,mode in zip(axs.flatten(), [PadMode.Zeros, PadMode.Border, PadMode.Reflection]):
    show_image(img.crop_pad((600,700), pad_mode=mode), ctx=ax, title=mode);
In [132]:
_,axs = plt.subplots(1,3,figsize=(12,4))
for ax,mode in zip(axs.flatten(), [PadMode.Zeros, PadMode.Border, PadMode.Reflection]):
    show_image(med.crop_pad((600,700), pad_mode=mode), ctx=ax, title=mode);

RandomCrop

In [31]:
_,axs = plt.subplots(1,3,figsize=(12,4))
f = RandomCrop(200)
for ax in axs: show_image(f(img), ctx=ax);
In [133]:
_,axs = plt.subplots(1,3,figsize=(12,4))
f = RandomCrop(200)
for ax in axs: show_image(f(med), ctx=ax);

Resize: Squish, Pad, Crop

In [33]:
_,axs = plt.subplots(1,3,figsize=(12,4))
for ax,method in zip(axs.flatten(), [ResizeMethod.Squish, ResizeMethod.Pad, ResizeMethod.Crop]):
    rsz = Resize(256, method=method)
    show_image(rsz(img, split_idx=0), ctx=ax, title=method);

Let's take an Image Search Engine Example:

In [14]:
_,axs = plt.subplots(1,3,figsize=(12,4))
for ax,method in zip(axs.flatten(), [ResizeMethod.Squish, ResizeMethod.Pad, ResizeMethod.Crop]):
    rsz = Resize(256, method=method)
    show_image(rsz(box, split_idx=0), ctx=ax, title=method);

RandomResizedCrop

In [35]:
crop = RandomResizedCrop(256)
_,axs = plt.subplots(3,3,figsize=(9,9))
for ax in axs.flatten():
    cropped = crop(img)
    show_image(cropped, ctx=ax);

Let's look at the OCR Example:

In [15]:
crop = RandomResizedCrop(256)
_,axs = plt.subplots(3,3,figsize=(9,9))
for ax in axs.flatten():
    cropped = crop(med)
    show_image(cropped, ctx=ax);

Ratio Resize

In [37]:
RatioResize(1024)(img)
Out[37]:

GPU Transforms

In [17]:
timg = TensorImage(array(img)).permute(2,0,1).float()/255.
def _batch_ex(bs): return TensorImage(timg[None].expand(bs, *timg.shape).clone())
In [27]:
timg_sat = TensorImage(array(sat)).permute(2,0,1).float()/255.
def sat_batch_ex(bs): return TensorImage(timg_sat[None].expand(bs, *timg_sat.shape).clone())
In [31]:
timg_med = TensorImage(array(med)).permute(2,0,1).float()/255.
def med_batch_ex(bs): return TensorImage(timg_med[None].expand(bs, *timg_med.shape).clone())
In [35]:
timg_box = TensorImage(array(box)).permute(2,0,1).float()/255.
def box_batch_ex(bs): return TensorImage(timg_box[None].expand(bs, *timg_box.shape).clone())

RandomResizedCropGPU

In [39]:
t = _batch_ex(8)
rrc = RandomResizedCropGPU(224, p=1.)
y = rrc(t)
_,axs = plt.subplots(2,4, figsize=(12,6))
for ax in axs.flatten():
    show_image(y[2], ctx=ax)
In [40]:
x = flip_mat(torch.randn(100,4,3))
test_eq(set(x[:,0,0].numpy()), {-1,1}) #might fail with probability 2*2**(-100) (picked only 1s or -1s)
In [41]:
dih = DeterministicFlip({'p':.3})

DeterministicFlip

In [19]:
t = _batch_ex(8)
dih = DeterministicFlip()
_,axs = plt.subplots(2,4, figsize=(12,6))
for i,ax in enumerate(axs.flatten()):
    y = dih(t)
    show_image(y[0], ctx=ax, title=f'Call {i}')
In [26]:
t_ = box_batch_ex(8)
dih = DeterministicFlip()
_,axs = plt.subplots(2,4, figsize=(12,6))
for i,ax in enumerate(axs.flatten()):
    y = dih(t_)
    show_image(y[0], ctx=ax, title=f'Call {i}')

Dihedral

In [43]:
t = _batch_ex(8)
dih = Dihedral(p=1., draw=list(range(8)))
y = dih(t)
y = t.dihedral_batch(p=1., draw=list(range(8)))
_,axs = plt.subplots(2,4, figsize=(12,5))
for i,ax in enumerate(axs.flatten()): show_image(y[i], ctx=ax, title=f'Flip {i}')
In [28]:
t = sat_batch_ex(8)
dih = Dihedral(p=1., draw=list(range(8)))
y = dih(t)
y = t.dihedral_batch(p=1., draw=list(range(8)))
_,axs = plt.subplots(2,4, figsize=(12,5))
for i,ax in enumerate(axs.flatten()): show_image(y[i], ctx=ax, title=f'Flip {i}')

DeterministicDihedral

In [140]:
t = _batch_ex(8)
dih = DeterministicDihedral()
_,axs = plt.subplots(2,4, figsize=(12,6))
for i,ax in enumerate(axs.flatten()):
    y = dih(t)
    show_image(y[0], ctx=ax, title=f'Call {i}')
In [29]:
t = sat_batch_ex(8)
dih = DeterministicDihedral()
_,axs = plt.subplots(2,4, figsize=(12,6))
for i,ax in enumerate(axs.flatten()):
    y = dih(t)
    show_image(y[0], ctx=ax, title=f'Call {i}')

Rotation

In [45]:
thetas = [-30,-15,0,15,30]
y = _batch_ex(5).rotate(draw=thetas, p=1.)
_,axs = plt.subplots(1,5, figsize=(15,3))
for i,ax in enumerate(axs.flatten()):
    show_image(y[i], ctx=ax, title=f'{thetas[i]} degrees')

Let's rewind to the satellite image:

In [30]:
thetas = [-30,-15,0,15,30]
y = sat_batch_ex(5).rotate(draw=thetas, p=1.)
_,axs = plt.subplots(1,5, figsize=(15,3))
for i,ax in enumerate(axs.flatten()):
    show_image(y[i], ctx=ax, title=f'{thetas[i]} degrees')
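
Under the hood, a rotation warps pixel coordinates with the standard 2D rotation matrix. Rotating a point by θ degrees about the origin:

```python
import math

def rotate_point(x, y, theta_deg):
    # standard 2D rotation matrix applied to (x, y)
    t = math.radians(theta_deg)
    return (x * math.cos(t) - y * math.sin(t),
            x * math.sin(t) + y * math.cos(t))

x, y = rotate_point(1.0, 0.0, 90)
print(round(x, 6), round(y, 6))  # 0.0 1.0
```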

Scale

In [46]:
scales = [0.8, 1., 1.1, 1.25, 1.5]
n = len(scales)
y = _batch_ex(n).zoom(draw=scales, p=1., draw_x=0.5, draw_y=0.5)
fig,axs = plt.subplots(1, n, figsize=(12,3))
fig.suptitle('Center zoom with different scales')
for i,ax in enumerate(axs.flatten()):
    show_image(y[i], ctx=ax, title=f'scale {scales[i]}')

This comes in handy when we're trying to zoom into an image

In [32]:
scales = [0.8, 1., 1.1, 1.25, 1.5]
n = len(scales)
y = med_batch_ex(n).zoom(draw=scales, p=1., draw_x=0.5, draw_y=0.5)
fig,axs = plt.subplots(1, n, figsize=(12,3))
fig.suptitle('Center zoom with different scales')
for i,ax in enumerate(axs.flatten()):
    show_image(y[i], ctx=ax, title=f'scale {scales[i]}')

Constant Scale and Random Centres

In [47]:
y = _batch_ex(4).zoom(p=1., draw=1.5)
fig,axs = plt.subplots(1,4, figsize=(12,3))
fig.suptitle('Constant scale and different random centers')
for i,ax in enumerate(axs.flatten()):
    show_image(y[i], ctx=ax)

Going back to the OCR Example:

In [33]:
y = med_batch_ex(4).zoom(p=1., draw=1.5)
fig,axs = plt.subplots(1,4, figsize=(12,3))
fig.suptitle('Constant scale and different random centers')
for i,ax in enumerate(axs.flatten()):
    show_image(y[i], ctx=ax)

Warp

In [48]:
scales = [-0.4, -0.2, 0., 0.2, 0.4]
warp = Warp(p=1., draw_y=scales, draw_x=0.)
y = warp(_batch_ex(5), split_idx=0)
fig,axs = plt.subplots(1,5, figsize=(15,3))
fig.suptitle('Vertical warping')
for i,ax in enumerate(axs.flatten()):
    show_image(y[i], ctx=ax, title=f'magnitude {scales[i]}')
In [36]:
scales = [-0.4, -0.2, 0., 0.2, 0.4]
warp = Warp(p=1., draw_y=scales, draw_x=0.)
y = warp(box_batch_ex(5), split_idx=0)
fig,axs = plt.subplots(1,5, figsize=(15,3))
fig.suptitle('Vertical warping')
for i,ax in enumerate(axs.flatten()):
    show_image(y[i], ctx=ax, title=f'magnitude {scales[i]}')
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
In [49]:
scales = [-0.4, -0.2, 0., 0.2, 0.4]
warp = Warp(p=1., draw_x=scales, draw_y=0.)
y = warp(_batch_ex(5), split_idx=0)
fig,axs = plt.subplots(1,5, figsize=(15,3))
fig.suptitle('Horizontal warping')
for i,ax in enumerate(axs.flatten()):
    show_image(y[i], ctx=ax, title=f'magnitude {scales[i]}')
In [37]:
scales = [-0.4, -0.2, 0., 0.2, 0.4]
warp = Warp(p=1., draw_x=scales, draw_y=0.)
y = warp(box_batch_ex(5), split_idx=0)
fig,axs = plt.subplots(1,5, figsize=(15,3))
fig.suptitle('Horizontal warping')
for i,ax in enumerate(axs.flatten()):
    show_image(y[i], ctx=ax, title=f'magnitude {scales[i]}')
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

Brightness

In [50]:
scales = [0.1, 0.3, 0.5, 0.7, 0.9]
y = _batch_ex(5).brightness(draw=scales, p=1.)
fig,axs = plt.subplots(1,5, figsize=(15,3))
for i,ax in enumerate(axs.flatten()):
    show_image(y[i], ctx=ax, title=f'scale {scales[i]}')
In [38]:
scales = [0.1, 0.3, 0.5, 0.7, 0.9]
y = med_batch_ex(5).brightness(draw=scales, p=1.)
fig,axs = plt.subplots(1,5, figsize=(15,3))
for i,ax in enumerate(axs.flatten()):
    show_image(y[i], ctx=ax, title=f'scale {scales[i]}')

Contrast

In [51]:
scales = [0.65, 0.8, 1., 1.25, 1.55]
y = _batch_ex(5).contrast(p=1., draw=scales)
fig,axs = plt.subplots(1,5, figsize=(15,3))
for i,ax in enumerate(axs.flatten()): show_image(y[i], ctx=ax, title=f'scale {scales[i]}')
In [39]:
scales = [0.65, 0.8, 1., 1.25, 1.55]
y = med_batch_ex(5).contrast(p=1., draw=scales)
fig,axs = plt.subplots(1,5, figsize=(15,3))
for i,ax in enumerate(axs.flatten()): show_image(y[i], ctx=ax, title=f'scale {scales[i]}')

Saturation

In [52]:
scales = [0., 0.5, 1., 1.5, 2.0]
y = _batch_ex(5).saturation(p=1., draw=scales)
fig,axs = plt.subplots(1,5, figsize=(15,3))
for i,ax in enumerate(axs.flatten()): show_image(y[i], ctx=ax, title=f'scale {scales[i]}')
In [40]:
scales = [0., 0.5, 1., 1.5, 2.0]
y = med_batch_ex(5).saturation(p=1., draw=scales)
fig,axs = plt.subplots(1,5, figsize=(15,3))
for i,ax in enumerate(axs.flatten()): show_image(y[i], ctx=ax, title=f'scale {scales[i]}')
In [53]:
fig,axs=plt.subplots(figsize=(20, 4),ncols=5)
axs[0].set_ylabel('Hue')
for ax in axs:
    ax.set_xlabel('Saturation')
    ax.set_yticklabels([])
    ax.set_xticklabels([])

hsv=torch.stack([torch.arange(0,2.1,0.01)[:,None].repeat(1,210),
                 torch.arange(0,1.05,0.005)[None].repeat(210,1),
                 torch.ones([210,210])])[None]
for ax,i in zip(axs,range(0,5)):
    if i>0: hsv[:,2].mul_(0.80)
    ax.set_title('V='+'%.1f' %0.8**i)
    ax.imshow(hsv2rgb(hsv)[0].permute(1,2,0))
In [54]:
scales = [0.5, 0.75, 1., 1.5, 1.75]
y = _batch_ex(len(scales)).hue(p=1., draw=scales)
fig,axs = plt.subplots(1,len(scales), figsize=(15,3))
for i,ax in enumerate(axs.flatten()): show_image(y[i], ctx=ax, title=f'scale {scales[i]}')
In [41]:
scales = [0.5, 0.75, 1., 1.5, 1.75]
y = sat_batch_ex(len(scales)).hue(p=1., draw=scales)
fig,axs = plt.subplots(1,len(scales), figsize=(15,3))
for i,ax in enumerate(axs.flatten()): show_image(y[i], ctx=ax, title=f'scale {scales[i]}')

Cutout

In [55]:
nrm = Normalize.from_stats(*imagenet_stats, cuda=False)
In [56]:
f = partial(cutout_gaussian, areas=[(100,200,100,200),(200,300,200,300)])
show_image(norm_apply_denorm(timg, f, nrm)[0]);
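
cutout_gaussian takes rectangular areas of the normalized image and fills them with Gaussian noise. A minimal sketch of the masking step on a plain 2D list, zeroing the region instead of adding noise:

```python
def cutout(img, areas):
    # blank out each (row_start, row_end, col_start, col_end) rectangle
    out = [row[:] for row in img]
    for y0, y1, x0, x1 in areas:
        for y in range(y0, y1):
            for x in range(x0, x1):
                out[y][x] = 0
    return out

img = [[1] * 4 for _ in range(4)]
masked = cutout(img, [(1, 3, 1, 3)])
print(sum(sum(row) for row in masked))  # 16 pixels minus a 2x2 hole = 12
```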

Random Erasing

In [57]:
tfm = RandomErasing(p=1., max_count=6)

_,axs = subplots(2,3, figsize=(12,6))
f = partial(tfm, split_idx=0)
for i,ax in enumerate(axs.flatten()): show_image(norm_apply_denorm(timg, f, nrm)[0], ctx=ax)
In [58]:
y = _batch_ex(6)
_,axs = plt.subplots(2,3, figsize=(12,6))
y = norm_apply_denorm(y, f, nrm)
for i,ax in enumerate(axs.flatten()): show_image(y[i], ctx=ax)

Zero Padding

In [60]:
tfms = aug_transforms(pad_mode='zeros', mult=2, min_scale=0.5)
y = _batch_ex(9)
for t in tfms: y = t(y, split_idx=0)
_,axs = plt.subplots(1,3, figsize=(12,3))
for i,ax in enumerate(axs.flatten()): show_image(y[i], ctx=ax)

TTA: Test Time Augmentation

A Learner object allows you to call a tta() method like so:

In [62]:
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(224))
learn = cnn_learner(dls, resnet34, metrics=error_rate)
learn.fine_tune(1)
epoch train_loss valid_loss error_rate time
0 0.132741 0.016442 0.004736 00:18
epoch train_loss valid_loss error_rate time
0 0.054450 0.013252 0.005413 00:22
In [63]:
learn.tta()
Out[63]:
(tensor([[1.0000e+00, 1.5211e-07],
         [1.0000e+00, 4.8723e-06],
         [9.8221e-01, 1.7787e-02],
         ...,
         [1.0000e+00, 1.6213e-06],
         [9.9998e-01, 1.5963e-05],
         [1.2467e-10, 1.0000e+00]]),
 tensor([0, 0, 0,  ..., 0, 0, 1]))
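
Conceptually, tta() runs the model over several augmented copies of each validation item and averages the predictions, which tends to smooth out augmentation-sensitive mistakes. A minimal sketch of the averaging step on plain nested lists:

```python
def tta_average(preds_per_aug):
    """Average class probabilities across augmented runs of the same batch."""
    n = len(preds_per_aug)
    n_items = len(preds_per_aug[0])
    n_classes = len(preds_per_aug[0][0])
    return [[sum(run[i][c] for run in preds_per_aug) / n
             for c in range(n_classes)]
            for i in range(n_items)]

# three augmented runs over one item with two classes
runs = [[[0.9, 0.1]], [[0.7, 0.3]], [[0.8, 0.2]]]
avg = tta_average(runs)
print([[round(p, 6) for p in row] for row in avg])  # [[0.8, 0.2]]
```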

Finish